# 目标是构建一个可供训练和测试的数据集

import csv
import os

import pandas as pd
from sklearn.model_selection import train_test_split

# Raw MovieLens-1M input files; all three live in the same directory.
_RAW_DIR = "../../data/ml-1m"
MOVIE_PATH = _RAW_DIR + "/movies.dat"
RATING_PATH = _RAW_DIR + "/ratings.dat"
USERS_PATH = _RAW_DIR + "/users.dat"

# Output directory, followed by the three files this script writes into it:
# the full combined table, the training split and the test split.
SET_PATH = "../../data/ml-1m_20190508"
MOVIE_RATING_PATH = SET_PATH + "/rating_combine_20190506.csv"
TRAIN_RATING_PATH = SET_PATH + "/TRAIN_20190506.csv"
TEST_RATING_PATH = SET_PATH + "/TEST_20190506.csv"


# Make sure the output directory exists before anything is written to it.
# exist_ok=True turns this into a no-op when the directory is already there
# and avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(SET_PATH, exist_ok=True)

# ----------------------------- Raw data processing ---------------------------


def _read_dat(path, columns):
    """Load one header-less, "::"-separated .dat file using *columns* as names."""
    # The two-character separator is why the python parsing engine is requested.
    return pd.read_csv(path, sep="::", header=None, names=columns,
                       engine='python')


# Load the three source tables.
movies = _read_dat(MOVIE_PATH, ["movieId", "movieName", "genres"])
users = _read_dat(USERS_PATH,
                  ["userId", "gender", "age", "occupation", "zipcode"])
rating = _read_dat(RATING_PATH, ["userId", "movieId", "rating", "timestamp"])

# Combine them: attach each rating to its movie, then to the user who gave it.
data = movies.merge(rating, on="movieId").merge(users, on="userId")

# Keep only the columns that go into the data set, with "rating" first.
data = data[["rating", "movieId", "movieName", "genres", "userId", "gender",
             "age", "occupation"]]

# Saved with "@" as a temporary single-character separator; the end of the
# script rewrites this file to use "::".
data.to_csv(MOVIE_RATING_PATH, index=False, sep="@")

# Train / test split: hold out 33% of the rows for testing; the fixed
# random_state keeps the split reproducible between runs.
#
# The whole table is split in a single call. Splitting the "rating" column
# and the remaining columns as two separate arrays and joining the pieces
# back together on the index selects exactly the same rows in the same order,
# so the extra join is not needed.
_columns = ["rating", "movieId", "movieName", "genres", "userId", "gender",
            "age", "occupation"]
train_set, test_set = train_test_split(
    data[_columns], test_size=0.33, random_state=10)

# Same temporary "@" separator as the combined file; rewritten to "::" later.
train_set.to_csv(TRAIN_RATING_PATH, index=False, sep="@")
test_set.to_csv(TEST_RATING_PATH, index=False, sep="@")

# Separator rewrite: turn the "@"-separated files written above into
# "::"-separated files, in place.


def _rewrite_separator(path, old_sep="@", new_sep="::"):
    """Rewrite the delimited text file at *path* in place with a new separator.

    The file is parsed with the csv module rather than a plain str.split, so
    a field that was quoted on output (because it contains *old_sep* or a
    double quote) is read back as one field, without the CSV quoting, instead
    of being cut apart at every *old_sep*.

    Args:
        path: UTF-8 text file to convert; it is overwritten.
        old_sep: Single-character separator the file currently uses.
        new_sep: Separator to write between fields; may be several characters.
    """
    # The file is replaced in place, so every row has to be read before the
    # same path is reopened for writing.
    # newline="" leaves line-ending handling to the csv module.
    with open(path, encoding="utf8", newline="") as f:
        rows = list(csv.reader(f, delimiter=old_sep))
    with open(path, "w", encoding="utf8") as f:
        f.writelines("%s\n" % new_sep.join(row) for row in rows)


for _path in (MOVIE_RATING_PATH, TRAIN_RATING_PATH, TEST_RATING_PATH):
    _rewrite_separator(_path)
